{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notebook written by [Zhedong Zheng](https://github.com/zhedongzheng)\n",
    "\n",
    "<img src=\"img/tfidf_flow.png\" width=\"300\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "brew install apache-spark\n",
    "pip3 install findspark\n",
    "\"\"\"\n",
    "import tensorflow as tf\n",
    "import findspark\n",
    "findspark.init()\n",
    "from pyspark.sql import SparkSession\n",
    "from pyspark.ml import Pipeline\n",
    "from pyspark.ml.feature import CountVectorizer, IDF\n",
    "from pyspark.ml.classification import LogisticRegression\n",
    "from pyspark.ml.evaluation import MulticlassClassificationEvaluator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_idx2word(_index_from=3):\n",
    "    word2idx = tf.keras.datasets.imdb.get_word_index()\n",
    "    word2idx = {k:(v+_index_from) for k,v in word2idx.items()}\n",
    "    word2idx[\"<pad>\"] = 0\n",
    "    word2idx[\"<start>\"] = 1\n",
    "    word2idx[\"<unk>\"] = 2\n",
    "    idx2word = {idx: w for w, idx in word2idx.items()}\n",
    "    return idx2word\n",
    "\n",
    "\n",
    "def make_df(x, y):\n",
    "    return sess.createDataFrame(\n",
    "        [(int(y_), [idx2word[idx] for idx in x_]) for x_, y_ in zip(x, y)],\n",
    "        ['label', 'words'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing Accuracy: 0.882\n"
     ]
    }
   ],
   "source": [
    "idx2word = get_idx2word()\n",
    "(X_train, y_train), (X_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=20000)\n",
    "\n",
    "sess = SparkSession.builder.appName('nlp').getOrCreate()\n",
    "\n",
    "pipeline = Pipeline(stages=[\n",
    "    CountVectorizer(inputCol='words',\n",
    "                    outputCol='tf'),\n",
    "    IDF(inputCol='tf',\n",
    "        outputCol='tfidf'),\n",
    "    LogisticRegression(featuresCol='tfidf',\n",
    "                       regParam=1.0),\n",
    "])\n",
    "\n",
    "df_train = make_df(X_train, y_train)\n",
    "df_test = make_df(X_test, y_test)\n",
    "\n",
    "prediction = pipeline.fit(df_train).transform(df_test)\n",
    "print(\"Testing Accuracy: %.3f\" % MulticlassClassificationEvaluator().evaluate(prediction))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
